{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load and parse data with TensorFlow\n",
    "\n",
    "A TensorFlow example to build input pipelines for loading data efficiently.\n",
    "\n",
    "\n",
    "- Numpy Arrays\n",
    "- Images\n",
    "- CSV file\n",
    "- Custom data from a Generator\n",
    "\n",
    "For more information about creating and loading TensorFlow's `TFRecords` data format, see: [tfrecords.ipynb](tfrecords.ipynb)\n",
    "\n",
    "- Author: Aymeric Damien\n",
    "- Project: https://github.com/aymericdamien/TensorFlow-Examples/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "\n",
    "import numpy as np\n",
    "import random\n",
    "import requests\n",
    "import string\n",
    "import tarfile\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Numpy Arrays\n",
    "\n",
    "Build a data pipeline over numpy arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a toy dataset (even and odd numbers, with respective labels of 0 and 1).\n",
    "evens = np.arange(0, 100, step=2, dtype=np.int32)\n",
    "evens_label = np.zeros(50, dtype=np.int32)\n",
    "odds = np.arange(1, 100, step=2, dtype=np.int32)\n",
    "odds_label = np.ones(50, dtype=np.int32)\n",
    "# Concatenate arrays\n",
    "features = np.concatenate([evens, odds])\n",
    "labels = np.concatenate([evens_label, odds_label])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.Graph().as_default():\n",
    "    # Create TF session.\n",
    "    sess = tf.Session()\n",
    "    \n",
    "    # Slice the numpy arrays (each row becoming a record).\n",
    "    data = tf.data.Dataset.from_tensor_slices((features, labels))\n",
    "    # Refill data indefinitely.  \n",
    "    data = data.repeat()\n",
    "    # Shuffle data.\n",
    "    data = data.shuffle(buffer_size=100)\n",
    "    # Batch data (aggregate records together).\n",
    "    data = data.batch(batch_size=4)\n",
    "    # Prefetch batch (pre-load batch for faster consumption).\n",
    "    data = data.prefetch(buffer_size=1)\n",
    "    \n",
    "    # Create an iterator over the dataset.\n",
    "    iterator = data.make_initializable_iterator()\n",
    "    # Initialize the iterator.\n",
    "    sess.run(iterator.initializer)\n",
    "\n",
    "    # Get next data batch.\n",
    "    d = iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[82 58 80 23] [0 0 0 1]\n",
      "[16 91 74 96] [0 1 0 0]\n",
      "[ 4 17 32 34] [0 1 0 0]\n",
      "[16  8 77 21] [0 0 1 1]\n",
      "[20 99 48 18] [0 1 0 0]\n"
     ]
    }
   ],
   "source": [
    "# Display data.\n",
    "for i in range(5):\n",
    "    x, y = sess.run(d)\n",
    "    print(x, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load CSV files\n",
    "\n",
    "Build a data pipeline from features stored in a CSV file. For this example, Titanic dataset will be used as a toy dataset stored in CSV format."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Titanic Dataset\n",
    "\n",
    "\n",
    "\n",
    "survived|pclass|name|sex|age|sibsp|parch|ticket|fare\n",
    "--------|------|----|---|---|-----|-----|------|----\n",
    "1|1|\"Allen, Miss. Elisabeth Walton\"|female|29|0|0|24160|211.3375\n",
    "1|1|\"Allison, Master. Hudson Trevor\"|male|0.9167|1|2|113781|151.5500\n",
    "0|1|\"Allison, Miss. Helen Loraine\"|female|2|1|2|113781|151.5500\n",
    "0|1|\"Allison, Mr. Hudson Joshua Creighton\"|male|30|1|2|113781|151.5500\n",
    "...|...|...|...|...|...|...|...|..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download Titanic dataset (in csv format).\n",
    "d = requests.get(\"https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/titanic_dataset.csv\")\n",
    "with open(\"titanic_dataset.csv\", \"wb\") as f:\n",
    "    f.write(d.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Titanic dataset.\n",
    "# Original features: survived,pclass,name,sex,age,sibsp,parch,ticket,fare\n",
    "# Select specific columns: survived,pclass,name,sex,age,fare\n",
    "column_to_use = [0, 1, 2, 3, 4, 8]\n",
    "record_defaults = [tf.int32, tf.int32, tf.string, tf.string, tf.float32, tf.float32]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.Graph().as_default():\n",
    "    # Create TF session.\n",
    "    sess = tf.Session()\n",
    "    \n",
    "    # Load the whole dataset file, and slice each line.\n",
    "    data = tf.data.experimental.CsvDataset(\"titanic_dataset.csv\", record_defaults, header=True, select_cols=column_to_use)\n",
    "    # Refill data indefinitely.  \n",
    "    data = data.repeat()\n",
    "    # Shuffle data.\n",
    "    data = data.shuffle(buffer_size=1000)\n",
    "    # Batch data (aggregate records together).\n",
    "    data = data.batch(batch_size=2)\n",
    "    # Prefetch batch (pre-load batch for faster consumption).\n",
    "    data = data.prefetch(buffer_size=1)\n",
    "    \n",
    "    # Create an iterator over the dataset.\n",
    "    iterator = data.make_initializable_iterator()\n",
    "    # Initialize the iterator.\n",
    "    sess.run(iterator.initializer)\n",
    "\n",
    "    # Get next data batch.\n",
    "    d = iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 0]\n",
      "[3 1]\n",
      "['Lam, Mr. Ali' 'Widener, Mr. Harry Elkins']\n",
      "['male' 'male']\n",
      "[ 0. 27.]\n",
      "[ 56.4958 211.5   ]\n",
      "\n",
      "[0 1]\n",
      "[1 1]\n",
      "['Baumann, Mr. John D' 'Daly, Mr. Peter Denis ']\n",
      "['male' 'male']\n",
      "[ 0. 51.]\n",
      "[25.925 26.55 ]\n",
      "\n",
      "[0 1]\n",
      "[3 1]\n",
      "['Assam, Mr. Ali' 'Newell, Miss. Madeleine']\n",
      "['male' 'female']\n",
      "[23. 31.]\n",
      "[  7.05  113.275]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Display data.\n",
    "for i in range(3):\n",
    "    survived, pclass, name, sex, age, fare = sess.run(d)\n",
    "    print(survived)\n",
    "    print(pclass)\n",
    "    print(name)\n",
    "    print(sex)\n",
    "    print(age)\n",
    "    print(fare)\n",
    "    print(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Images\n",
    "\n",
    "Build a data pipeline by loading images from disk. For this example, Oxford Flowers dataset will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download Oxford 17 flowers dataset.\n",
    "d = requests.get(\"http://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz\")\n",
    "with open(\"17flowers.tgz\", \"wb\") as f:\n",
    "    f.write(d.content)\n",
    "# Extract archive.\n",
    "with tarfile.open(\"17flowers.tgz\") as t:\n",
    "    t.extractall()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a file to list all images path and their corresponding label.\n",
    "with open('jpg/dataset.csv', 'w') as f:\n",
    "    c = 0\n",
    "    for i in range(1360):\n",
    "        f.write(\"jpg/image_%04i.jpg,%i\\n\" % (i+1, c))\n",
    "        if (i+1) % 80 == 0:\n",
    "            c += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.Graph().as_default():\n",
    "    \n",
    "    # Load Images.\n",
    "    with open(\"jpg/dataset.csv\") as f:\n",
    "        dataset_file = f.read().splitlines()\n",
    "    \n",
    "    # Create TF session.\n",
    "    sess = tf.Session()\n",
    "\n",
    "    # Load the whole dataset file, and slice each line.\n",
    "    data = tf.data.Dataset.from_tensor_slices(dataset_file)\n",
    "    # Refill data indefinitely.\n",
    "    data = data.repeat()\n",
    "    # Shuffle data.\n",
    "    data = data.shuffle(buffer_size=1000)\n",
    "\n",
    "    # Load and pre-process images.\n",
    "    def load_image(path):\n",
    "        # Read image from path.\n",
    "        image = tf.io.read_file(path)\n",
    "        # Decode the jpeg image to array [0, 255].\n",
    "        image = tf.image.decode_jpeg(image)\n",
    "        # Resize images to a common size of 256x256.\n",
    "        image = tf.image.resize(image, [256, 256])\n",
    "        # Rescale values to [-1, 1].\n",
    "        image = 1. - image / 127.5\n",
    "        return image\n",
    "    # Decode each line from the dataset file.\n",
    "    def parse_records(line):\n",
    "        # File is in csv format: \"image_path,label_id\".\n",
    "        # TensorFlow requires a default value, but it will never be used.\n",
    "        image_path, image_label = tf.io.decode_csv(line, [\"\", 0])\n",
    "        # Apply the function to load images.\n",
    "        image = load_image(image_path)\n",
    "        return image, image_label\n",
    "    # Use 'map' to apply the above functions in parallel.\n",
    "    data = data.map(parse_records, num_parallel_calls=4)\n",
    "\n",
    "    # Batch data (aggregate images-array together).\n",
    "    data = data.batch(batch_size=2)\n",
    "    # Prefetch batch (pre-load batch for faster consumption).\n",
    "    data = data.prefetch(buffer_size=1)\n",
    "    \n",
    "    # Create an iterator over the dataset.\n",
    "    iterator = data.make_initializable_iterator()\n",
    "    # Initialize the iterator.\n",
    "    sess.run(iterator.initializer)\n",
    "\n",
    "    # Get next data batch.\n",
    "    d = iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[[ 0.1294117   0.05098033  0.46666664]\n",
      "   [ 0.1368872   0.05098033  0.48909312]\n",
      "   [ 0.0931372   0.0068627   0.46029407]\n",
      "   ...\n",
      "   [ 0.23480386  0.0522058   0.6102941 ]\n",
      "   [ 0.12696075 -0.05416667  0.38063723]\n",
      "   [-0.10024512 -0.28848052  0.10367644]]\n",
      "\n",
      "  [[ 0.04120708 -0.06118262  0.36256123]\n",
      "   [ 0.08009624 -0.02229345  0.41640145]\n",
      "   [ 0.06797445 -0.04132879  0.41923058]\n",
      "   ...\n",
      "   [ 0.2495715   0.06697345  0.6251221 ]\n",
      "   [ 0.12058818 -0.06094813  0.37577546]\n",
      "   [-0.05184889 -0.24009418  0.16777915]]\n",
      "\n",
      "  [[-0.09234071 -0.22738981  0.20484066]\n",
      "   [-0.03100491 -0.17312062  0.2811274 ]\n",
      "   [ 0.01051998 -0.13237214  0.3376838 ]\n",
      "   ...\n",
      "   [ 0.27787983  0.07494056  0.64203525]\n",
      "   [ 0.11533964 -0.09005249  0.3869906 ]\n",
      "   [-0.02704227 -0.23958337  0.19454747]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[ 0.07913595 -0.13069856  0.29874384]\n",
      "   [ 0.10140878 -0.09445572  0.35912937]\n",
      "   [ 0.08869672 -0.08415675  0.41446364]\n",
      "   ...\n",
      "   [ 0.25821072  0.22463232  0.69197303]\n",
      "   [ 0.31636214  0.25750512  0.79362744]\n",
      "   [ 0.09552741  0.01709598  0.57395875]]\n",
      "\n",
      "  [[ 0.09019601 -0.12156868  0.3098039 ]\n",
      "   [ 0.17446858 -0.02271283  0.43218917]\n",
      "   [ 0.06583172 -0.10818791  0.39230233]\n",
      "   ...\n",
      "   [ 0.27021956  0.23664117  0.70269513]\n",
      "   [ 0.19560927  0.1385014   0.6740407 ]\n",
      "   [ 0.04364848 -0.03478289  0.5220798 ]]\n",
      "\n",
      "  [[ 0.02830875 -0.18345594  0.24791664]\n",
      "   [ 0.12937105 -0.06781042  0.38709164]\n",
      "   [ 0.01120263 -0.162817    0.33767325]\n",
      "   ...\n",
      "   [ 0.25989532  0.22631687  0.69237083]\n",
      "   [ 0.1200884   0.06298059  0.5985198 ]\n",
      "   [ 0.05961001 -0.01882136  0.53804135]]]\n",
      "\n",
      "\n",
      " [[[ 0.3333333   0.25490195  0.05882347]\n",
      "   [ 0.3333333   0.25490195  0.05882347]\n",
      "   [ 0.3340686   0.24705875  0.03039211]\n",
      "   ...\n",
      "   [-0.5215688  -0.4599266  -0.14632356]\n",
      "   [-0.5100491  -0.47083342 -0.03725493]\n",
      "   [-0.43419123 -0.39497554  0.05992639]]\n",
      "\n",
      "  [[ 0.34117645  0.26274508  0.0666666 ]\n",
      "   [ 0.35646445  0.2630821   0.0744791 ]\n",
      "   [ 0.3632046   0.2548713   0.04384762]\n",
      "   ...\n",
      "   [-0.9210479  -0.84267783 -0.4540485 ]\n",
      "   [-0.9017464  -0.8390626  -0.3507018 ]\n",
      "   [-0.83339334 -0.7632048  -0.2534927 ]]\n",
      "\n",
      "  [[ 0.3646446   0.2706495   0.06678915]\n",
      "   [ 0.37248772  0.27837008  0.07445425]\n",
      "   [ 0.38033658  0.27053267  0.05950326]\n",
      "   ...\n",
      "   [-0.94302344 -0.84222686 -0.30278325]\n",
      "   [-0.91017747 -0.8090074  -0.18615782]\n",
      "   [-0.83437514 -0.7402575  -0.08192408]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[ 0.64705884  0.654902    0.67058825]\n",
      "   [ 0.6318321   0.63967526  0.65536153]\n",
      "   [ 0.63128924  0.6391324   0.65481865]\n",
      "   ...\n",
      "   [ 0.6313726   0.57647055  0.51372546]\n",
      "   [ 0.6078431   0.53725487  0.4823529 ]\n",
      "   [ 0.6078431   0.53725487  0.4823529 ]]\n",
      "\n",
      "  [[ 0.654902    0.654902    0.6704657 ]\n",
      "   [ 0.654902    0.654902    0.6704657 ]\n",
      "   [ 0.64778835  0.64778835  0.6492474 ]\n",
      "   ...\n",
      "   [ 0.6392157   0.5843137   0.5215686 ]\n",
      "   [ 0.6393325   0.56874424  0.5138422 ]\n",
      "   [ 0.63106614  0.5604779   0.50557595]]\n",
      "\n",
      "  [[ 0.654902    0.64705884  0.6313726 ]\n",
      "   [ 0.6548728   0.64702964  0.63134336]\n",
      "   [ 0.64705884  0.63210785  0.6377451 ]\n",
      "   ...\n",
      "   [ 0.63244915  0.5775472   0.5148021 ]\n",
      "   [ 0.6698529   0.5992647   0.5443627 ]\n",
      "   [ 0.6545358   0.5839475   0.5290455 ]]]] [5 9]\n"
     ]
    }
   ],
   "source": [
    "# Display data.\n",
    "for i in range(1):\n",
    "    batch_x, batch_y = sess.run(d)\n",
    "    print(batch_x, batch_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load data from a Generator\n",
    "\n",
    "Build a data pipeline from a custom generator. For this example, a toy generator yielding random string, vector and it is used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a dummy generator.\n",
    "def generate_features():\n",
    "    # Function to generate a random string.\n",
    "    def random_string(length):\n",
    "        return ''.join(random.choice(string.ascii_letters) for m in xrange(length))\n",
    "    # Return a random string, a random vector, and a random int.\n",
    "    yield random_string(4), np.random.uniform(size=4), random.randint(0, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.Graph().as_default():\n",
    "\n",
    "    # Create TF session.\n",
    "    sess = tf.Session()\n",
    "\n",
    "    # Create TF dataset from the generator.\n",
    "    data = tf.data.Dataset.from_generator(generate_features, output_types=(tf.string, tf.float32, tf.int32))\n",
    "    # Refill data indefinitely.\n",
    "    data = data.repeat()\n",
    "    # Shuffle data.\n",
    "    data = data.shuffle(buffer_size=100)\n",
    "    # Batch data (aggregate records together).\n",
    "    data = data.batch(batch_size=4)\n",
    "    # Prefetch batch (pre-load batch for faster consumption).\n",
    "    data = data.prefetch(buffer_size=1)\n",
    "\n",
    "    # Create an iterator over the dataset.\n",
    "    iterator = data.make_initializable_iterator()\n",
    "    # Initialize the iterator.\n",
    "    sess.run(iterator.initializer)\n",
    "\n",
    "    # Get next data batch.\n",
    "    d = iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['AvCS' 'kAaI' 'QwGX' 'IWOI'] [[0.6096093  0.32192084 0.26622605 0.70250475]\n",
      " [0.72534287 0.7637426  0.19977213 0.74121326]\n",
      " [0.6930984  0.09409562 0.4063325  0.5002103 ]\n",
      " [0.05160935 0.59411395 0.276416   0.98264974]] [1 3 5 6]\n",
      "['EXjS' 'brvx' 'kwNz' 'eFOb'] [[0.34355283 0.26881003 0.70575935 0.7503411 ]\n",
      " [0.9584373  0.27466875 0.27802315 0.9563204 ]\n",
      " [0.19129485 0.07014314 0.0932724  0.20726128]\n",
      " [0.28744072 0.81736153 0.37507302 0.8984588 ]] [1 9 7 0]\n",
      "['vpSa' 'UuqW' 'xaTO' 'milw'] [[0.2942028  0.8228986  0.5793326  0.16651365]\n",
      " [0.28259405 0.599063   0.2922477  0.95071274]\n",
      " [0.23645316 0.00258607 0.06772221 0.7291911 ]\n",
      " [0.12861755 0.31435087 0.576638   0.7333119 ]] [3 5 8 4]\n",
      "['UBBb' 'MUXs' 'nLJB' 'OBGl'] [[0.2677402  0.17931737 0.02607645 0.85898155]\n",
      " [0.58647937 0.727203   0.13329858 0.8898983 ]\n",
      " [0.13872191 0.47390288 0.7061665  0.08478573]\n",
      " [0.3786016  0.22002582 0.91989636 0.45837343]] [ 5  8  0 10]\n",
      "['kiiz' 'bQYG' 'WpUU' 'AuIY'] [[0.74781317 0.13744462 0.9236441  0.63558507]\n",
      " [0.23649399 0.35303807 0.0951511  0.03541444]\n",
      " [0.33599988 0.6906629  0.97166294 0.55850506]\n",
      " [0.90997607 0.5545979  0.43635726 0.9127501 ]] [8 1 4 4]\n"
     ]
    }
   ],
   "source": [
    "# Display data.\n",
    "for i in range(5):\n",
    "    batch_str, batch_vector, batch_int = sess.run(d)\n",
    "    print(batch_str, batch_vector, batch_int)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf1",
   "language": "python",
   "name": "tf1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
